{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# $$CatBoost\\ Feature\\ Importance\\ Tutorial$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*Original source of the notebook: https://github.com/catboost/tutorials/blob/master/model_analysis/feature_importance_tutorial.ipynb*\n",
    "**Credits to the creators of the original notebook.**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Weights & Biases"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash\n",
    "# Install the W&B client from source (%%bash must be the first line of the cell)\n",
    "git clone https://github.com/wandb/client/\n",
    "cd client\n",
    "python setup.py install\n",
    "\n",
    "# Once the functionality is released on PyPI, install with pip instead:\n",
    "# pip install wandb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Go to https://www.wandb.com/ and create an account, or log in if you already have one\n",
    "# Create your project and obtain a W&B token (available after logging in and creating the project)\n",
    "\n",
    "!wandb login ${WANDB_TOKEN}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import wandb\n",
    "from wandb.catboost import plot_feature_importances"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "                Logging results to <a href=\"https://wandb.com\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n",
       "                Project page: <a href=\"https://app.wandb.ai/neomatrix369/catboost_plot_feature_importances\" target=\"_blank\">https://app.wandb.ai/neomatrix369/catboost_plot_feature_importances</a><br/>\n",
       "                Run page: <a href=\"https://app.wandb.ai/neomatrix369/catboost_plot_feature_importances/runs/3u8me0as\" target=\"_blank\">https://app.wandb.ai/neomatrix369/catboost_plot_feature_importances/runs/3u8me0as</a><br/>\n",
       "            "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[34m\u001b[1mwandb\u001b[0m: Wandb version 0.8.36 is available!  To upgrade, please run:\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m:  $ pip install wandb --upgrade\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "W&B Run: https://app.wandb.ai/neomatrix369/catboost_plot_feature_importances/runs/3u8me0as"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Enter the name of the project you created in the step above\n",
    "\n",
    "wandb.init(project='catboost_plot_feature_importances')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CatBoost Feature Importance"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### It is often important to understand which features contributed most to the final result. For this purpose, CatBoost models provide the `get_feature_importance` method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from catboost import CatBoostRegressor, Pool\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.datasets import load_iris"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### First, let's prepare the dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 2.59 ms, sys: 0 ns, total: 2.59 ms\n",
      "Wall time: 1.84 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "full_dataset_dict = load_iris()  # default return_X_y=False gives a dict-like Bunch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['data', 'target', 'target_names', 'DESCR', 'feature_names', 'filename'])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "full_dataset_dict.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Features are the four iris measurements; the target is the species index (0, 1 or 2)\n",
    "X, y = np.array(full_dataset_dict['data']), np.array(full_dataset_dict['target'])\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)\n",
    "\n",
    "# Keep X two-dimensional (samples x features) so the model sees all four features\n",
    "train_pool = Pool(X_train, label=y_train)\n",
    "test_pool = Pool(X_test, label=y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Let's train CatBoost:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0:\tlearn: 1.8330124\ttest: 1.8892445\tbest: 1.8892445 (0)\ttotal: 48.3ms\tremaining: 917ms\n",
      "1:\tlearn: 1.7183943\ttest: 1.7743163\tbest: 1.7743163 (1)\ttotal: 49ms\tremaining: 441ms\n",
      "2:\tlearn: 1.6143398\ttest: 1.6673153\tbest: 1.6673153 (2)\ttotal: 195ms\tremaining: 1.11s\n",
      "3:\tlearn: 1.5106876\ttest: 1.5604859\tbest: 1.5604859 (3)\ttotal: 197ms\tremaining: 786ms\n",
      "4:\tlearn: 1.4161375\ttest: 1.4631750\tbest: 1.4631750 (4)\ttotal: 203ms\tremaining: 610ms\n",
      "5:\tlearn: 1.3273977\ttest: 1.3724472\tbest: 1.3724472 (5)\ttotal: 205ms\tremaining: 479ms\n",
      "6:\tlearn: 1.2439417\ttest: 1.2854552\tbest: 1.2854552 (6)\ttotal: 206ms\tremaining: 383ms\n",
      "7:\tlearn: 1.1689810\ttest: 1.2083937\tbest: 1.2083937 (7)\ttotal: 207ms\tremaining: 311ms\n",
      "8:\tlearn: 1.0967593\ttest: 1.1355992\tbest: 1.1355992 (8)\ttotal: 383ms\tremaining: 468ms\n",
      "9:\tlearn: 1.0306900\ttest: 1.0702588\tbest: 1.0702588 (9)\ttotal: 383ms\tremaining: 383ms\n",
      "10:\tlearn: 0.9674108\ttest: 1.0048687\tbest: 1.0048687 (10)\ttotal: 488ms\tremaining: 399ms\n",
      "11:\tlearn: 0.9107697\ttest: 0.9458671\tbest: 0.9458671 (11)\ttotal: 592ms\tremaining: 394ms\n",
      "12:\tlearn: 0.8531990\ttest: 0.8859644\tbest: 0.8859644 (12)\ttotal: 593ms\tremaining: 319ms\n",
      "13:\tlearn: 0.8009974\ttest: 0.8318166\tbest: 0.8318166 (13)\ttotal: 593ms\tremaining: 254ms\n",
      "14:\tlearn: 0.7531866\ttest: 0.7856505\tbest: 0.7856505 (14)\ttotal: 598ms\tremaining: 199ms\n",
      "15:\tlearn: 0.7072903\ttest: 0.7384994\tbest: 0.7384994 (15)\ttotal: 598ms\tremaining: 150ms\n",
      "16:\tlearn: 0.6645256\ttest: 0.6938817\tbest: 0.6938817 (16)\ttotal: 601ms\tremaining: 106ms\n",
      "17:\tlearn: 0.6255223\ttest: 0.6537334\tbest: 0.6537334 (17)\ttotal: 708ms\tremaining: 78.7ms\n",
      "18:\tlearn: 0.5875867\ttest: 0.6140484\tbest: 0.6140484 (18)\ttotal: 709ms\tremaining: 37.3ms\n",
      "19:\tlearn: 0.5518532\ttest: 0.5763208\tbest: 0.5763208 (19)\ttotal: 823ms\tremaining: 0us\n",
      "\n",
      "bestTest = 0.5763208156\n",
      "bestIteration = 19\n",
      "\n",
      "CPU times: user 801 ms, sys: 171 ms, total: 971 ms\n",
      "Wall time: 897 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "params = {'iterations': 20, 'learning_rate': 0.07, 'depth': 16, 'eval_metric': 'RMSE',\n",
    "          'od_wait': 100, 'allow_writing_files': True, 'verbose': True}\n",
    "# Alternatives to experiment with: 'n_estimators': 1000 (an alias of 'iterations'), 'grow_policy': 'Lossguide'\n",
    "model = CatBoostRegressor(random_seed=1234, **params)\n",
    "best_model = model.fit(train_pool, eval_set=test_pool, use_best_model=True, verbose=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### CatBoost provides several types of feature importance. One of them is PredictionDiff: a vector with the contribution of each feature to the difference in RawFormulaVal between a pair of objects."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### To pick such a pair, first get the model's predictions on the test data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 4.62 ms, sys: 3.71 ms, total: 8.33 ms\n",
      "Wall time: 2.71 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "prediction = model.predict(X_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Comparing `prediction` with `y_test` shows which test objects the model gets most wrong; these are natural candidates for a PredictionDiff comparison."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### W&B Feature Importance visualisation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Go to your dashboard to see the Catboost Feature Importance graph plotted.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<wandb.viz.Visualize at 0x7f305cf3c400>"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "plot_feature_importances(model, feature_names=full_dataset_dict['feature_names'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### You should now see a bar chart titled \"Feature importances\" for the trained CatBoost model in your W&B dashboard at https://app.wandb.ai/[your username]/[your project name]/runs/[your run id]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
